In [1]:
import os, sys

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn

from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision import transforms, utils

from torch import optim
from torch import cuda
from torchvision import transforms

from sklearn.model_selection import StratifiedKFold

from PIL import Image
import cv2

import random
import math
import gc, pickle
from tqdm import tqdm_notebook as tqdm

import warnings
warnings.filterwarnings('ignore')
In [2]:
base_path = "../input/anime-faces/data/"
img_path = os.listdir(base_path)
img_path = [path for path in img_path if 'png' in path]
In [3]:
def read_img(path):
    base_path = "../input/anime-faces/data/"
    img = cv2.imread(f"{base_path}{path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
In [4]:
len(img_path)
Out[4]:
21551
In [5]:
plt.figure(figsize=(30,10))
sample_path = random.sample(img_path, 15)
for i, path in enumerate(sample_path):
    plt.subplot(3,5,i+1)
    
    img = read_img(path)
    
    img = cv2.normalize(img, None, alpha=0, beta=1.0, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    
    plt.imshow(img)
plt.show()
In [6]:
sorted([
    f"../input/dcgan-generate-anime-faces-4/{a}" for a in\
            os.listdir('../input/dcgan-generate-anime-faces-4/') if "gene" in a])
Out[6]:
['../input/dcgan-generate-anime-faces-4/generator10110.pt',
 '../input/dcgan-generate-anime-faces-4/generator1348.pt',
 '../input/dcgan-generate-anime-faces-4/generator2359.pt',
 '../input/dcgan-generate-anime-faces-4/generator337.pt',
 '../input/dcgan-generate-anime-faces-4/generator3370.pt',
 '../input/dcgan-generate-anime-faces-4/generator4381.pt',
 '../input/dcgan-generate-anime-faces-4/generator5392.pt',
 '../input/dcgan-generate-anime-faces-4/generator6403.pt',
 '../input/dcgan-generate-anime-faces-4/generator7414.pt',
 '../input/dcgan-generate-anime-faces-4/generator8425.pt',
 '../input/dcgan-generate-anime-faces-4/generator9436.pt']
In [7]:
class Generator(nn.Module):
    def __init__(self, img_size, latent_dim, channels):
        super(Generator, self).__init__()

        self.init_size = img_size // 8
        self.conv_init_dim = 512
        
        self.l1 = nn.Sequential(nn.Linear(latent_dim, self.conv_init_dim * self.init_size ** 2))
        self.channels = channels

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(self.conv_init_dim),
            nn.Upsample(scale_factor=2),
            
            nn.Conv2d(self.conv_init_dim, self.conv_init_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(self.conv_init_dim, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Upsample(scale_factor=2),
            
            nn.Conv2d(self.conv_init_dim, self.conv_init_dim//2, 3, stride=1, padding=1),
            nn.BatchNorm2d(self.conv_init_dim//2, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(self.conv_init_dim//2, self.conv_init_dim//2, 3, stride=1, padding=1),
            nn.BatchNorm2d(self.conv_init_dim//2, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Upsample(scale_factor=2),
            
            nn.Conv2d(self.conv_init_dim//2, self.conv_init_dim//2, 3, stride=1, padding=1),
            nn.BatchNorm2d(self.conv_init_dim//2, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(self.conv_init_dim//2, self.conv_init_dim//4, 3, stride=1, padding=1),
            nn.BatchNorm2d(self.conv_init_dim//4, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(self.conv_init_dim//4, self.channels, 3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], self.conv_init_dim, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
img_size = 64
latent_dim = 128
channels=3
b_size=40

paths = ['../input/dcgan-generate-anime-faces-4/generator337.pt',
         '../input/dcgan-generate-anime-faces-4/generator1348.pt',
         '../input/dcgan-generate-anime-faces-4/generator2359.pt',
         '../input/dcgan-generate-anime-faces-4/generator3370.pt',
         '../input/dcgan-generate-anime-faces-4/generator4381.pt',
         '../input/dcgan-generate-anime-faces-4/generator5392.pt',
         '../input/dcgan-generate-anime-faces-4/generator6403.pt',
         '../input/dcgan-generate-anime-faces-4/generator7414.pt',
         '../input/dcgan-generate-anime-faces-4/generator8425.pt',
         '../input/dcgan-generate-anime-faces-4/generator9436.pt',
         '../input/dcgan-generate-anime-faces-4/generator10110.pt',]

generator = Generator(img_size, latent_dim, channels)

z = np.random.normal(0, 1, (b_size, latent_dim)).tolist()
z = torch.FloatTensor(z)

for path in paths:
    state_dict = torch.load(path)
    generator.load_state_dict(state_dict)
    
    imgs = generator(z)
    print(imgs.size())
    
    imgs = imgs.detach()
    plt.figure(figsize=(30,14))
    grid = utils.make_grid(imgs, nrow=10)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title(path.split('/')[-1])
    
    plt.show()
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
In [8]:
generator = Generator(img_size, latent_dim, channels)
for _ in range(4):
    z = []
    z_1 = np.random.normal(0, 1, (latent_dim))
    z_2 = np.random.normal(0, 1, (latent_dim))
    z_3 = np.random.normal(0, 1, (latent_dim))
    for p in np.linspace(0,1,20):
        z.append(((p-1)*z_1+p*z_2).tolist())
    for p in np.linspace(0,1,20):
        z.append(((1-p)*z_2+p*z_3).tolist())
    z = torch.FloatTensor(z)

    for path in paths[-1:]:
        state_dict = torch.load(path)
        generator.load_state_dict(state_dict)

        imgs = generator(z)
        print(imgs.size())

        imgs = imgs.detach()
        plt.figure(figsize=(30,14))
        grid = utils.make_grid(imgs, nrow=10)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.title(path.split('/')[-1])

        plt.show()
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])
torch.Size([40, 3, 64, 64])